Face Generation with DCGAN

This notebook demonstrates face generation using a DCGAN (Deep Convolutional GAN) trained on CelebA.

What makes this fun:

  • Train a GAN that generates realistic human faces
  • Watch the model learn facial features progressively
  • Fast training on GPU (~30 minutes for quality results)
  • Generate unlimited unique faces

Why DCGAN? Stable architecture with convolutional layers, batch normalization, and proven effectiveness for image generation.

In [1]:
# Installation
# !pip install torch torchvision matplotlib tqdm
In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
Using device: cuda

Part 1: Data Preparation

We'll use the CelebA dataset, roughly 200k aligned celebrity face images. Each face will be center-cropped, resized to 64x64, and normalized to [-1, 1].
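The `Normalize` transform with mean 0.5 and std 0.5 per channel maps `ToTensor`'s [0, 1] pixel range onto [-1, 1], which matches the `Tanh` output of the generator we build later. A pure-Python sketch of the per-pixel mapping (the helper names here are illustrative, not part of torchvision):

```python
# transforms.Normalize applies (x - mean) / std per channel.
# With mean = std = 0.5, the [0, 1] range from ToTensor becomes [-1, 1].
def normalize(x, mean=0.5, std=0.5):
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    # inverse mapping, useful before displaying generated images
    return y * std + mean

print(normalize(0.0))  # -1.0
print(normalize(1.0))  # 1.0
print(denormalize(normalize(0.25)))  # 0.25
```

When displaying images below, `make_grid(..., normalize=True, value_range=(-1, 1))` performs this inverse mapping for us.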

In [3]:
# Load CelebA dataset
image_size = 64
batch_size = 128

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])

# Download CelebA dataset (this may take a few minutes the first time)
train_dataset = torchvision.datasets.CelebA(
    root='./data',
    split='train',
    download=True,
    transform=transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

print(f'Training samples: {len(train_dataset):,}')
print(f'Batches per epoch: {len(train_loader):,}')
print(f'Image shape: {train_dataset[0][0].shape}')

# Visualize real samples
samples = next(iter(train_loader))[0][:64]
grid = make_grid(samples, nrow=8, normalize=True, value_range=(-1, 1))
plt.figure(figsize=(12, 12))
plt.imshow(grid.permute(1, 2, 0).cpu())
plt.title('Real CelebA Face Images', fontsize=16)
plt.axis('off')
plt.tight_layout()
plt.show()
Downloading...
From (original): https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM
From (redirected): https://drive.usercontent.google.com/download?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM&confirm=t&uuid=4f1cf598-a56a-45f9-8f98-c2bb644e80eb
To: /content/data/celeba/img_align_celeba.zip
100%|██████████| 1.44G/1.44G [00:12<00:00, 120MB/s]
Downloading...
From: https://drive.google.com/uc?id=0B7EVK8r0v71pblRyaVFSWGxPY0U
To: /content/data/celeba/list_attr_celeba.txt
100%|██████████| 26.7M/26.7M [00:00<00:00, 41.6MB/s]
Downloading...
From: https://drive.google.com/uc?id=1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS
To: /content/data/celeba/identity_CelebA.txt
100%|██████████| 3.42M/3.42M [00:00<00:00, 36.2MB/s]
Downloading...
From: https://drive.google.com/uc?id=0B7EVK8r0v71pbThiMVRxWXZ4dU0
To: /content/data/celeba/list_bbox_celeba.txt
100%|██████████| 6.08M/6.08M [00:00<00:00, 336MB/s]
Downloading...
From: https://drive.google.com/uc?id=0B7EVK8r0v71pd0FJY3Blby1HUTQ
To: /content/data/celeba/list_landmarks_align_celeba.txt
100%|██████████| 12.2M/12.2M [00:00<00:00, 363MB/s]
Downloading...
From: https://drive.google.com/uc?id=0B7EVK8r0v71pY0NSMzRuSXJEVkk
To: /content/data/celeba/list_eval_partition.txt
100%|██████████| 2.84M/2.84M [00:00<00:00, 290MB/s]
Training samples: 162,770
Batches per epoch: 1,272
Image shape: torch.Size([3, 64, 64])

Part 2: Build DCGAN

DCGAN uses strided convolutions (transposed in the generator) with batch normalization for stable training. No fully connected or pooling layers!
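Each `ConvTranspose2d` layer upsamples according to PyTorch's formula `out = (in - 1) * stride - 2 * padding + kernel` (with `dilation` and `output_padding` left at their defaults). A quick arithmetic check that the generator's five layers really take a 1x1 latent map to 64x64:

```python
def convtranspose2d_out(size, kernel, stride, padding):
    # PyTorch output-size formula for ConvTranspose2d
    # (dilation=1, output_padding=0 assumed, as in our generator)
    return (size - 1) * stride - 2 * padding + kernel

size = 1  # the latent vector is treated as a 1x1 spatial map
for kernel, stride, padding in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]:
    size = convtranspose2d_out(size, kernel, stride, padding)
    print(size)  # 4, 8, 16, 32, 64
```

The discriminator runs the same arithmetic in reverse with strided `Conv2d` layers.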

In [4]:
class Generator(nn.Module):
    """
    DCGAN Generator for 64x64 RGB images.
    Architecture: latent vector -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
    """
    def __init__(self, latent_dim=100, ngf=64):
        super().__init__()
        self.latent_dim = latent_dim

        self.main = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # State: (ngf*8) x 4 x 4

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # State: (ngf*4) x 8 x 8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # State: (ngf*2) x 16 x 16

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # State: (ngf) x 32 x 32

            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: 3 x 64 x 64
        )

    def forward(self, noise):
        return self.main(noise)


class Discriminator(nn.Module):
    """
    DCGAN Discriminator for 64x64 RGB images.
    Architecture: 64x64 -> 32x32 -> 16x16 -> 8x8 -> 4x4 -> 1
    """
    def __init__(self, ndf=64):
        super().__init__()

        self.main = nn.Sequential(
            # Input: 3 x 64 x 64
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (ndf) x 32 x 32

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (ndf*2) x 16 x 16

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (ndf*4) x 8 x 8

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (ndf*8) x 4 x 4

            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # Output: 1 x 1 x 1
        )

    def forward(self, image):
        return self.main(image).view(-1, 1)


# Initialize models
generator = Generator(latent_dim=100, ngf=64).to(device)
discriminator = Discriminator(ndf=64).to(device)

# Initialize weights (DCGAN paper recommendation)
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

generator.apply(weights_init)
discriminator.apply(weights_init)

print(f'Generator parameters: {sum(p.numel() for p in generator.parameters()):,}')
print(f'Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}')
Generator parameters: 3,576,704
Discriminator parameters: 2,765,568
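The printed counts can be checked by hand: each bias-free (de)convolution contributes kernel x kernel x in_channels x out_channels weights, and each affine `BatchNorm2d` contributes 2 x channels (weight and bias; the running statistics are buffers, not parameters). A pure-Python tally reproducing both numbers:

```python
def conv_params(c_in, c_out, k=4):
    # bias=False everywhere, so only the k*k*c_in*c_out weight tensor counts
    return k * k * c_in * c_out

def bn_params(c):
    # affine BatchNorm2d: one weight and one bias per channel
    return 2 * c

gen_layers = [(100, 512), (512, 256), (256, 128), (128, 64), (64, 3)]
gen_total = sum(conv_params(i, o) for i, o in gen_layers)
gen_total += sum(bn_params(o) for _, o in gen_layers[:-1])  # no BN before Tanh

disc_layers = [(3, 64), (64, 128), (128, 256), (256, 512), (512, 1)]
disc_total = sum(conv_params(i, o) for i, o in disc_layers)
disc_total += sum(bn_params(o) for _, o in disc_layers[1:-1])  # no BN on first/last layer

print(gen_total)   # 3576704
print(disc_total)  # 2765568
```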

Part 3: Training Setup

In [5]:
# Loss and optimizers
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Fixed noise for visualization (64 samples)
fixed_noise = torch.randn(64, generator.latent_dim, 1, 1).to(device)

print('Training setup complete!')
print(f'Fixed noise shape: {fixed_noise.shape}')
Training setup complete!
Fixed noise shape: torch.Size([64, 100, 1, 1])
In [6]:
# Training function
def train_epoch(generator, discriminator, loader, optimizer_g, optimizer_d, criterion, device):
    generator.train()
    discriminator.train()

    d_losses = []
    g_losses = []

    for real_images, _ in tqdm(loader, desc='Training'):
        batch_size = real_images.size(0)
        real_images = real_images.to(device)

        # Labels for real and fake
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ============================================
        # Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        # ============================================
        optimizer_d.zero_grad()

        # Real images
        real_output = discriminator(real_images)
        d_loss_real = criterion(real_output, real_labels)

        # Fake images
        noise = torch.randn(batch_size, generator.latent_dim, 1, 1).to(device)
        fake_images = generator(noise)
        fake_output = discriminator(fake_images.detach())
        d_loss_fake = criterion(fake_output, fake_labels)

        # Total discriminator loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # ============================================
        # Train Generator: maximize log(D(G(z)))
        # ============================================
        optimizer_g.zero_grad()

        # Generate fake images
        noise = torch.randn(batch_size, generator.latent_dim, 1, 1).to(device)
        fake_images = generator(noise)
        fake_output = discriminator(fake_images)

        # Generator wants discriminator to think fakes are real
        g_loss = criterion(fake_output, real_labels)
        g_loss.backward()
        optimizer_g.step()

        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    return np.mean(d_losses), np.mean(g_losses)
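A common stability tweak not used above is one-sided label smoothing: train the discriminator on real images against a target of 0.9 instead of 1.0 (e.g. `real_labels = torch.full((batch_size, 1), 0.9, device=device)`), so it never becomes fully confident on real data. A small pure-math sketch of why this matters for BCE (the 0.9 value is a conventional choice, not something this notebook uses):

```python
import math

def bce(p, y, eps=1e-7):
    # binary cross-entropy of prediction p against target y
    p = min(max(p, eps), 1 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

p = 0.99  # discriminator almost certain a real image is real
print(round(bce(p, 1.0), 4))  # 0.0101 -> with a hard label the loss vanishes
print(round(bce(p, 0.9), 4))  # 0.4696 -> the smoothed target keeps pressure on D
```

Because a confident discriminator passes weaker gradients to the generator, keeping its loss bounded away from zero can help the two networks stay balanced.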

Part 4: Train the GAN

Watch the generated faces improve progressively! Early epochs show blurry, face-like blobs; later epochs show sharper, more realistic features.

In [7]:
epochs = 20
sample_interval = 2  # Show samples every N epochs

history = {'d_loss': [], 'g_loss': []}

for epoch in range(epochs):
    print(f'\n=== Epoch {epoch+1}/{epochs} ===')

    d_loss, g_loss = train_epoch(generator, discriminator, train_loader,
                                  optimizer_g, optimizer_d, criterion, device)

    history['d_loss'].append(d_loss)
    history['g_loss'].append(g_loss)

    print(f'D Loss: {d_loss:.4f} | G Loss: {g_loss:.4f}')

    # Generate samples at intervals
    if (epoch + 1) % sample_interval == 0 or epoch == 0:
        generator.eval()
        with torch.no_grad():
            fake_images = generator(fixed_noise)

        grid = make_grid(fake_images, nrow=8, normalize=True, value_range=(-1, 1))
        plt.figure(figsize=(10, 10))
        plt.imshow(grid.permute(1, 2, 0).cpu())
        plt.title(f'Generated Faces - Epoch {epoch+1}', fontsize=16)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

print('\nTraining complete!')
=== Epoch 1/20 ===
Training: 100%|██████████| 1272/1272 [01:34<00:00, 13.51it/s]
D Loss: 0.6442 | G Loss: 6.2513
=== Epoch 2/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.79it/s]
D Loss: 0.6499 | G Loss: 3.5977
=== Epoch 3/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.66it/s]
D Loss: 0.7304 | G Loss: 2.6784

=== Epoch 4/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.75it/s]
D Loss: 0.7767 | G Loss: 2.4050
=== Epoch 5/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.70it/s]
D Loss: 0.7661 | G Loss: 2.3691

=== Epoch 6/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.61it/s]
D Loss: 0.7316 | G Loss: 2.3640
=== Epoch 7/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.68it/s]
D Loss: 0.7087 | G Loss: 2.4317

=== Epoch 8/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.57it/s]
D Loss: 0.6845 | G Loss: 2.4593
=== Epoch 9/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.61it/s]
D Loss: 0.6667 | G Loss: 2.5262

=== Epoch 10/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.66it/s]
D Loss: 0.6172 | G Loss: 2.5825
=== Epoch 11/20 ===
Training: 100%|██████████| 1272/1272 [01:34<00:00, 13.44it/s]
D Loss: 0.5901 | G Loss: 2.7086

=== Epoch 12/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.70it/s]
D Loss: 0.5598 | G Loss: 2.8499
=== Epoch 13/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.57it/s]
D Loss: 0.4907 | G Loss: 3.0304

=== Epoch 14/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.64it/s]
D Loss: 0.4463 | G Loss: 3.2286
=== Epoch 15/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.55it/s]
D Loss: 0.4580 | G Loss: 3.2955

=== Epoch 16/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.60it/s]
D Loss: 0.4306 | G Loss: 3.4094
=== Epoch 17/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.67it/s]
D Loss: 0.4157 | G Loss: 3.5188

=== Epoch 18/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.69it/s]
D Loss: 0.4293 | G Loss: 3.5587
=== Epoch 19/20 ===
Training: 100%|██████████| 1272/1272 [01:32<00:00, 13.71it/s]
D Loss: 0.3659 | G Loss: 3.7131

=== Epoch 20/20 ===
Training: 100%|██████████| 1272/1272 [01:33<00:00, 13.66it/s]
D Loss: 0.3593 | G Loss: 3.6743
Training complete!
In [8]:
# Plot training curves
plt.figure(figsize=(10, 5))
plt.plot(history['d_loss'], label='Discriminator Loss', linewidth=2)
plt.plot(history['g_loss'], label='Generator Loss', linewidth=2)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Training Loss Over Time', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.show()
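The curves above plot per-epoch means. The per-batch losses that `train_epoch` collects are much noisier, and if you choose to plot those instead, a simple moving average makes the trend readable. A minimal NumPy sketch (`moving_average` is our own helper, not a NumPy function):

```python
import numpy as np

def moving_average(values, window=5):
    # boxcar smoothing; the output has len(values) - window + 1 points
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='valid')

noisy = [1.0, 3.0, 2.0, 4.0, 3.0, 5.0, 4.0]
print(moving_average(noisy, window=3))  # [2. 3. 3. 4. 4.]
```

For example, `plt.plot(moving_average(d_losses, 50))` over the raw batch losses would show the same downward drift in the discriminator loss with far less jitter.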

Part 5: Generate More Faces

Generate new random faces on demand!

In [9]:
def generate_faces(num_samples=16):
    """
    Generate random faces.

    Args:
        num_samples: Number of faces to generate
    """
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_samples, generator.latent_dim, 1, 1).to(device)
        generated = generator(noise)

    grid = make_grid(generated, nrow=4, normalize=True, value_range=(-1, 1))
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.title('Generated Faces', fontsize=16)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

# Generate multiple batches
print('Generating random faces...')
for i in range(3):
    print(f'\nBatch {i+1}:')
    generate_faces(num_samples=16)
Generating random faces...

Batch 1:
Batch 2:
Batch 3:

Part 6: Generate High-Resolution Grid

Create a large grid showing the variety of generated faces.

In [10]:
# Generate a large grid of faces
generator.eval()
num_faces = 64

with torch.no_grad():
    noise = torch.randn(num_faces, generator.latent_dim, 1, 1).to(device)
    generated_faces = generator(noise)

grid = make_grid(generated_faces, nrow=8, normalize=True, value_range=(-1, 1))
plt.figure(figsize=(15, 15))
plt.imshow(grid.permute(1, 2, 0).cpu())
plt.title('Generated Face Gallery (64 unique faces)', fontsize=18)
plt.axis('off')
plt.tight_layout()
plt.show()

Part 7: Latent Space Exploration

Interpolate between two random points in latent space to see smooth transitions between faces.

In [11]:
def interpolate_latent(num_steps=10):
    """
    Interpolate between two random latent vectors to show smooth transitions.
    """
    generator.eval()

    # Two random starting points
    z1 = torch.randn(1, generator.latent_dim, 1, 1).to(device)
    z2 = torch.randn(1, generator.latent_dim, 1, 1).to(device)

    interpolations = []
    with torch.no_grad():
        for alpha in torch.linspace(0, 1, num_steps):
            z = (1 - alpha) * z1 + alpha * z2
            img = generator(z)
            interpolations.append(img)

    interpolations = torch.cat(interpolations)
    grid = make_grid(interpolations, nrow=num_steps, normalize=True, value_range=(-1, 1))

    plt.figure(figsize=(15, 3))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.title('Latent Space Interpolation (smooth transitions)', fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

# Show multiple interpolations
print('Latent space interpolation - watch faces morph smoothly!')
for i in range(3):
    print(f'\nInterpolation {i+1}:')
    interpolate_latent(num_steps=10)
Latent space interpolation - watch faces morph smoothly!

Interpolation 1:
Interpolation 2:
Interpolation 3:
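The linear interpolation above cuts a straight chord through latent space, and the midpoint of a chord between two high-dimensional Gaussian samples has an atypically small norm, which can produce washed-out intermediate faces. Spherical interpolation (slerp) is a commonly used alternative that follows the arc instead; a NumPy sketch, not part of the notebook above:

```python
import numpy as np

def slerp(z1, z2, t):
    # spherical interpolation: walk along the arc between z1 and z2 so that
    # intermediate points keep a norm typical of Gaussian samples
    u1 = z1 / np.linalg.norm(z1)
    u2 = z2 / np.linalg.norm(z2)
    omega = np.arccos(np.clip(np.dot(u1, u2), -1.0, 1.0))
    if np.isclose(omega, 0.0):
        return (1 - t) * z1 + t * z2  # nearly parallel vectors: lerp is fine
    return (np.sin((1 - t) * omega) * z1 + np.sin(t * omega) * z2) / np.sin(omega)

rng = np.random.default_rng(0)
z1 = rng.standard_normal(100)
z2 = rng.standard_normal(100)
path = [slerp(z1, z2, t) for t in np.linspace(0, 1, 10)]
```

To try it in `interpolate_latent`, replace `z = (1 - alpha) * z1 + alpha * z2` with a slerp on the flattened vectors and reshape the result back to `(1, latent_dim, 1, 1)` before passing it to the generator.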